
import torch
import numpy as np
import device
#device = "cuda" if torch.cuda.is_available() else "cpu"
import sys
from neuralEQ import *
from eq import *
import torch.nn.utils.prune as prune

import matplotlib.pyplot as plt



class simNeuralEQ():
	'''
	Description:
		Training and evaluation of the neural equalizer (neuralEQ) are implemented in this class.
		A BER checker for the equalized output is implemented as well.
	Params:
		txDataTrain(float, list)	: Training Y set. (target of the neural network output)
		rxDataTrain(float, list)	: Training X set. (neural network input)
		txDataTest(float, list)		: Test Y set.
		rxDataTest(float, list)		: Test X set.
		neuralEQ(class)				: Pre-defined neural network.
		mod(str)					: Modulation selection. (nrz / pam4 / pam8)
	'''
	def __init__(self, txDataTrain, rxDataTrain, txDataTest, rxDataTest, neuralEQ, mod):
		'''
		Description:
			Keep references to the train/test sequences and the network, and
			derive the number of modulation levels (modNum) from mod.
		Params:
			txDataTrain(float, list)	: Training Y set.
			rxDataTrain(float, list)	: Training X set.
			txDataTest(float, list)		: Test Y set. Its length is stored as dataSize.
			rxDataTest(float, list)		: Test X set.
			neuralEQ(class)				: Pre-defined neural network.
			mod(str)					: 'nrz' -> 2 levels, 'pam4' -> 4 levels, 'pam8' -> 8 levels.
		'''
		self.txDataTrain, self.rxDataTrain = txDataTrain, rxDataTrain
		self.txDataTest, self.rxDataTest = txDataTest, rxDataTest
		self.neuralEQ = neuralEQ
		self.dataSize = len(txDataTest)

		#@@ modNum is only assigned for a known modulation name; any other value leaves it unset.
		for modName, levels in (('nrz', 2), ('pam4', 4), ('pam8', 8)):
			if mod == modName:
				self.modNum = levels
				break

	def berChecker(self, ref, rxOut, delay=0, offsetStart=0, offsetEnd=1):
		'''
		Description:
			Check if rxOut(equalized output) is same with ref(tx output).
			Delay must be set properly because equalizer can add some delay to input signal.
			OffsetStart and offsetEnd can be used to ignore some unstable edge of sequences.
		Params:
			ref(float, list)	: reference is tx data. NRZ: -1,1 // PAM4: -1,-1/3,1/3,1 // PAM8: ...
			rxOut(int, list)	: equalized output. NRZ: 0,1 // PAM4: 0,1,2,3 // PAM8: 0,1,2 ... 7
			delay(int)			: delay when comparing ref with rxOut.
			offsetStart(int)	: It will ignore the ber before offsetStart index.
			offsetEnd(int)		: It will ignore the ber after (-offsetEnd) index
		'''

		#@@ Convert ref format to rxOut format. (NRZ: -1,1 -> 0,1 // PAM4: -1,-1/3,1/3,1 -> 0,1,2,3)
		if (self.modNum==2):
			amp=2.0/1
			ref = torch.where( ref < -1+amp/2, int(0),
											int(1)
											)
		elif (self.modNum==4):
			amp = 2.0/3
			ref = torch.where(	ref < -1+amp/2+amp*0, int(0),
				torch.where(	ref < -1+amp/2+amp*1, int(1),
				torch.where(	ref < -1+amp/2+amp*2, int(2), 
												int(3)
													)))
		elif (self.modNum==8):
			amp = 2.0/7
			ref = torch.where(	ref < -1+amp/2+amp*0, int(0),
				torch.where(	ref < -1+amp/2+amp*1, int(1),
				torch.where(	ref < -1+amp/2+amp*2, int(2),
				torch.where(	ref < -1+amp/2+amp*3, int(3),
				torch.where(	ref < -1+amp/2+amp*4, int(4),
				torch.where(	ref < -1+amp/2+amp*5, int(5),
				torch.where(	ref < -1+amp/2+amp*6, int(6),
													 int(7)
													)))))))

		
		#rxOut = torch.array(rxOut)
		#print(rxOut)
		rxOut = rxOut[delay:]
		length = min(len(ref),len(rxOut))
		if (len(rxOut) < len(ref)) : # overlapped viterbi case
			ref = ref[:length]
			rxOut = rxOut[:length]
		#bitErr = abs(ref-rxOut[delay:delay+len(ref)])
		bitErr = abs(ref-rxOut[:len(ref)])
		#print ("rxOut: ",rxOut)
		#print ("bitErr: ",bitErr)
		bitErr = torch.where(bitErr==0, 0, 1)
		if 0:	# print all compared bit
			print ("")
			for k in range(len(ref)):
				print (ref[k], " ", rxOut[k], " ",bitErr[k])
			print (offsetStart, " ", offsetEnd)

		bitErr = bitErr[offsetStart:-offsetEnd]
		#ber = float(sum(bitErr)/2)/len(bitErr)
		ber = float(sum(bitErr))/len(bitErr)
		#print ("bitErr: ",bitErr)
		return ber, sum(bitErr)


	def curatingData(self, rxData, txData, inSize, outSize, batchSize, delay=0, shuffle=False):
		'''
		Descriptions:
			curatingData is post-processing of rxData and txData for feeding to neural network.
			For example,
			rxData = [a0, a1, a2, a3, a4, a5, a6]
			txData = [b0, b1, b2, b3, b4, b5, b6]
			after curating with inSize=5, outSize=1, delay =2,
			rxDataSetBatch = [[a0, a1, a2, a3, a4], [a1, a2, a3, a4, a5] ... x batchSize]
			txDataSetBatch = [[b2], [b3], ... x batchSize]
			inSize must be larger than outSize
		Params:
			rxData(float, list)	: Neural network input
			txData(float, list) : Neural network output
			inSize(int)			: Neural network input size
			outSize(int)		: Neural network output size
			batchSize(int)		: Mini-batch size for training
			delay(int)			: Offset index when pairing rxData with txData.
		'''
		#@@ Initialize variables	
		rxDataBatch = []
		txDataBatch = []
		rxDataSet = []
		txDataSet = []
		rxData = list(rxData)
		txData = list(txData)
		#@@ Calculate # of total data segment according to inSize and outSize
		dataSegments = int((len(rxData)-inSize+outSize)/outSize)
		#print(f"inSize: {inSize}, outSize: {outSize}")
		#print(f"dataSegments: {dataSegments}")
		#print(f"txData: {txData[0]} ")

		for k in range(dataSegments-(len(rxData)-len(txData))):
			#@@ Segment is a unit set for feeding neural network.
			rxSegment = rxData[k*outSize:k*outSize+inSize]
			txSegment = txData[delay+k*outSize:delay+k*outSize+outSize]
			if 0:
				print (rxSegment)
				print (txSegment)

			rxDataBatch.append(rxSegment)
			txDataBatch.append(txSegment[0])
			#@@ If it reaches to batchsize, wrapping and start to create new batch.
			if (k%batchSize == batchSize-1):
				rxDataSet.append(rxDataBatch)
				txDataSet.append(txDataBatch)
				rxDataBatch = []
				txDataBatch = []
				#sys.exit()

		#print (f"rxDataSet: {rxDataSet}")
		#print (f"txDataSet: {txDataSet}")
		rxDataSet = np.array(rxDataSet)
		txDataSet = np.array(txDataSet)
		#sys.exit()
		if 0:
			print (f"rxSegment: {np.array(rxSegment)}")
			print (f"txSegment: {np.array(txSegment[0])}")

			print (f"rxSegment.shape: {np.array(rxSegment).shape}")
			print (f"txSegment.shape: {np.array(txSegment[0]).shape}")
			print (f"rxDataBatch.shape: {np.array(rxDataBatch).shape}")
			print (f"txDataBatch.shape: {np.array(txDataBatch).shape}")
			print (f"rxDataSet.shape: {rxDataSet.shape}")
			print (f"txDataSet.shape: {txDataSet.shape}")
			sys.exit()
		#print(f"len(rxDataSet) : {len(rxDataSet)}")
		#print(f"len(txDataSet) : {len(txDataSet)}")
		if shuffle:
			indices = np.arange(len(rxDataSet))
			np.random.shuffle(indices)
			rxDataSet = rxDataSet[indices]
			txDataSet = txDataSet[indices]

		return torch.Tensor(rxDataSet), torch.Tensor(txDataSet)

	def prune(self, pruneRatio):
		'''
		Description:
			Global unstructured pruning according to pruneRatio.
		Params:
			pruneRatio(float)	: Pruning ratio
		'''
		parameters_to_prune = []
		for module_name, module in self.neuralEQ.named_modules():
			if isinstance(module, torch.nn.Linear):
				parameters_to_prune.append((module, "weight"))
		prune.global_unstructured(
			parameters_to_prune,
			pruning_method=prune.L1Unstructured,
			amount=pruneRatio,
		)
		
	def trainNeuralEQ(self, lossFn, opt, inSize=8, outSize=4, batchSize=64, delay=0, plot=False, rxDataTrainNew=None, txDataTrainNew=None):
		'''
		Description:
			Train the neural EQ for one pass over the training data.
			First, mini-batches are generated with 'curatingData'.
			Each mini-batch is fed to the network, the selected loss is computed,
			and one optimizer step is taken.
		Params:
			lossFn(str)		: Loss function selection (crossEntropy / manualCrossEntropy / mse)
							  crossEntropy       : y is cast to integer class indices.
							  manualCrossEntropy : y weights the negative log-softmax of the prediction
							                       (presumably one-hot labels -- TODO confirm with caller).
							  mse                : mean squared error between prediction and y.
			opt(class)		: Optimizer
			inSize(int)		: Neural network input size
			outSize(int)	: Step between consecutive input windows
			batchSize(int)	: Mini-batch size
			delay(int)		: Offset index when pairing rxData with txData.
			plot(bool)		: Record the loss of every mini-batch and save the curve to 'train.png'.
			rxDataTrainNew	: Overrides rxDataTrain from class construction when not None.
			txDataTrainNew	: Training Y set used together with rxDataTrainNew.
		Returns:
			Loss tensor of the last mini-batch.
		Raises:
			ValueError		: lossFn is not one of the supported names, or the data is too
							  short to build a single complete mini-batch.
		'''
		## rxData size must be same or larger than txData size.

		#@@ Reject an unknown loss before touching the network; otherwise no loss would ever be defined.
		if lossFn not in ('crossEntropy', 'manualCrossEntropy', 'mse'):
			raise ValueError(f"unsupported lossFn '{lossFn}' (expected crossEntropy / manualCrossEntropy / mse)")

		lossList = []
		batchIdxList = []
		#@@ turn on neural net train mode
		self.neuralEQ.train()

		#@@ Override the stored training set if rxDataTrainNew is set.
		if rxDataTrainNew is not None:
			rxSource, txSource = rxDataTrainNew, txDataTrainNew
		else:
			rxSource, txSource = self.rxDataTrain, self.txDataTrain

		#@@ Curating data according to given parameters
		(rxDataSet, txDataSet) = self.curatingData(rxSource, txSource, inSize, outSize, batchSize, delay=delay)
		if len(rxDataSet) == 0:
			raise ValueError("training data is too short to build a single complete mini-batch")

		#@@ Run training with given dataSet
		for batchIdx, (x, y) in enumerate(zip(rxDataSet, txDataSet)):
			x = x.to(device.device)
			y = y.to(device.device)

			#@@ Run inference
			pred = self.neuralEQ(x)

			#@@ loss function selection
			if lossFn == 'crossEntropy':
				#@@ cross_entropy needs integer class indices; cast directly on the current device.
				loss = torch.nn.functional.cross_entropy(pred, y.to(torch.long))
			elif lossFn == 'manualCrossEntropy':
				#@@ log_softmax instead of log(softmax(.)): softmax can underflow to 0,
				#@@ and 0 * log(0) would then turn the whole loss into NaN.
				logHypo = torch.nn.functional.log_softmax(pred, dim=1)
				loss = (y*-logHypo).sum(dim=1).mean()
			else:
				loss = torch.nn.functional.mse_loss(pred, y)

			opt.zero_grad()
			loss.backward()
			opt.step()
			#@@ Optional hook: networks that define detachA() get it called after every step.
			if hasattr(self.neuralEQ, 'detachA'):
				self.neuralEQ.detachA()

			#@@ Only read the loss back when it is going to be plotted.
			if plot:
				lossList.append(loss.item())
				batchIdxList.append(batchIdx)

		if (plot):
			plt.figure(0)
			plt.plot(batchIdxList, lossList)
			plt.grid(True)
			plt.yscale('log')
			plt.ylim([0.000000001, 1])
			plt.savefig('train.png')

		return loss


	def evalNeuralEQ(self, lossFn, inSize=8, outSize=4, batchSize=64, delay=0, rxDataTestNew=None, txDataTestNew=None):
		'''
		Description:
			Evaluate the neural EQ with given params.
			First, mini-batches are generated with 'curatingData'.
			Each mini-batch is fed to the network, the prediction is turned into
			level indices, and the errors are counted with 'berChecker'.
		Params:
			lossFn(str)		: Not used
			inSize(int)		: Neural network input size
			outSize(int)	: Step between consecutive input windows
			batchSize(int)	: Mini-batch size
			delay(int)		: Offset index when pairing rxData with txData.
			rxDataTestNew	: Overrides rxDataTest from class construction when not None.
			txDataTestNew	: Test Y set used together with rxDataTestNew.
		Returns:
			(testLoss, ber)	: testLoss is always NaN (no loss is computed here),
							  ber is the accumulated error count divided by the tx data length.
		'''

		#@@ Turn on neural net eval mode
		self.neuralEQ.eval()

		#@@ Override the stored test set if rxDataTestNew is set.
		useOverride = rxDataTestNew is not None
		rxSource = rxDataTestNew if useOverride else self.rxDataTest
		txSource = txDataTestNew if useOverride else self.txDataTest
		(rxDataSet, txDataSet) = self.curatingData(rxSource, txSource, inSize, outSize, batchSize, delay=delay)
		datalen = len(txDataTestNew) if useOverride else self.dataSize

		be = 0

		#@@ Evaluation for given data sets
		with torch.no_grad():
			for x, y in zip(rxDataSet, txDataSet):
				x = x.to(device.device)
				y = y.to(device.device)

				pred = self.neuralEQ(x)

				if pred.shape[-1] == 1: # for oneOutput mode for nrz
					decOut = torch.where(pred > 0, 1, 0)
					#@@ A column-shaped decision would broadcast against a flat label vector
					#@@ inside berChecker; give it the label layout when the element counts agree.
					if decOut.shape != y.shape and decOut.numel() == y.numel():
						decOut = decOut.reshape(y.shape)
				else:
					#@@ Output index 0 is taken as the highest level, hence the reversal.
					#@@ NOTE(review): assumes that output ordering of the network -- confirm.
					decOut = self.modNum-1-torch.argmax(pred, dim=1)

				#@@ ber check with inference data
				#@@ NRZ: y=-1 or 1 // PAM4: y=-1 or -1/3 or 1/3 or 1 // PAM8: ...
				#@@ NOTE(review): berChecker is called with its default offsetEnd=1, so the last
				#@@ symbol of every mini-batch is left out of the error count.
				_, beTmp = self.berChecker(y, decOut)
				be += beTmp

		#@@ np.nan is the spelling available in every NumPy release (the np.NaN alias is gone in NumPy 2.0).
		testLoss = np.nan
		#@@ NOTE(review): the denominator is the tx data length, not the number of symbols
		#@@ actually compared, so the returned ber is a slight under-estimate.
		ber = float(be)/datalen

		return testLoss, ber


if __name__ == '__main__':
	#@@ Smoke test of curatingData: sliding windows of 4 rx samples, one label per window, mini-batches of 1.
	nEQ = neuralEQ(inSize=4, outSize=1)
	#@@ 'mod' is a required constructor argument; curatingData itself does not depend on it.
	simNeq = simNeuralEQ(txDataTrain=np.arange(10), rxDataTrain=np.arange(10), txDataTest=np.arange(10), rxDataTest=np.arange(10), neuralEQ=nEQ, mod='nrz')
	a, b = simNeq.curatingData(rxData=np.arange(10), txData=np.array([[0, 1]]*10), inSize=4, outSize=1, batchSize=1)
	print(f'rxDataSet: {a}')
	print(f'txDataSet: {b}')
